package dev.langchain4j.guardrail;

import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.response.ChatResponse;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * The result of the validation of an {@link OutputGuardrail}.
 * <p>
 * A result is either a success (optionally carrying a rewritten text and/or a rewritten result object) or a
 * failure / fatal failure carrying one or more {@link Failure}s.
 */
public final class OutputGuardrailResult implements GuardrailResult<OutputGuardrailResult> {
    private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult();

    private final Result result;
    private final String successfulText;
    private final Object successfulResult;
    private final List<Failure> failures;

    private OutputGuardrailResult(
            Result result, String successfulText, Object successfulResult, List<Failure> failures) {
        this.result = ensureNotNull(result, "result");
        this.successfulText = successfulText;
        this.successfulResult = successfulResult;
        // Non-empty failure lists are copied into a private mutable list because blockRetry() replaces
        // entries in place; storing the caller's list directly would throw UnsupportedOperationException
        // for immutable inputs such as List.of(...). Empty lists stay immutable so that the shared
        // SUCCESS instance cannot be modified through failures().
        this.failures = (failures == null || failures.isEmpty()) ? List.of() : new ArrayList<>(failures);
    }

    private OutputGuardrailResult() {
        this(Result.SUCCESS, null, null, Collections.emptyList());
    }

    private OutputGuardrailResult(String successfulText) {
        this(Result.SUCCESS_WITH_RESULT, successfulText, null, Collections.emptyList());
    }

    private OutputGuardrailResult(String successfulText, Object successfulResult) {
        this(Result.SUCCESS_WITH_RESULT, successfulText, successfulResult, Collections.emptyList());
    }

    OutputGuardrailResult(List<Failure> failures, boolean fatal) {
        this(fatal ? Result.FATAL : Result.FAILURE, null, null, failures);
    }

    OutputGuardrailResult(Failure failure, boolean fatal) {
        // The primary constructor makes its own mutable copy; Stream.of (unlike List.of) also tolerates
        // a null element, matching the previous behavior.
        this(Stream.of(failure).collect(Collectors.toList()), fatal);
    }

    /**
     * Gets a successful output guardrail result.
     *
     * @return The shared successful result, carrying neither a rewritten text nor a rewritten result.
     */
    public static OutputGuardrailResult success() {
        return SUCCESS;
    }

    /**
     * Produces a successful result with specific success text.
     *
     * @param successfulText
     *            The text of the successful result. When {@code null}, {@link #success()} is returned.
     * @return The result of a successful output guardrail validation with a specific text.
     */
    public static OutputGuardrailResult successWith(String successfulText) {
        return (successfulText == null) ? success() : new OutputGuardrailResult(successfulText);
    }

    /**
     * Produces a successful result with specific success text and a rewritten result object.
     *
     * @param successfulText
     *            The text of the successful result.
     * @param successfulResult
     *            The object generated by this successful result. When {@code null}, this behaves exactly like
     *            {@link #successWith(String)}.
     * @return The result of a successful output guardrail validation with a specific text and result object.
     */
    public static OutputGuardrailResult successWith(String successfulText, Object successfulResult) {
        // Without a result object this is the single-argument variant, which also maps a null text to
        // success() rather than producing a "rewritten" result that has no content at all.
        return (successfulResult == null)
                ? successWith(successfulText)
                : new OutputGuardrailResult(successfulText, successfulResult);
    }

    /**
     * Produces a non-fatal failure.
     *
     * @param failures A list of {@link Failure}s
     *
     * @return The result of a failed output guardrail validation.
     */
    public static OutputGuardrailResult failure(List<Failure> failures) {
        return new OutputGuardrailResult(failures, false);
    }

    /**
     * Whether or not the guardrail is forcing a retry.
     *
     * @return {@code true} if this is not a success and at least one failure requests a retry.
     */
    public boolean isRetry() {
        return !isSuccess() && this.failures.stream().anyMatch(Failure::retry);
    }

    /**
     * Whether or not the guardrail is forcing a reprompt.
     *
     * @return {@code true} if this is not a success and at least one failure carries a reprompt message.
     */
    public boolean isReprompt() {
        return !isSuccess() && this.failures.stream().anyMatch(failure -> failure.reprompt() != null);
    }

    /**
     * Block all retries for this result.
     * <p>
     * Every failure is replaced, in place, by {@link Failure#blockRetry()}, so {@link #isRetry()} returns
     * {@code false} afterwards. Results without failures (e.g. {@link #success()}) are left untouched.
     *
     * @return This result, to allow call chaining.
     */
    public OutputGuardrailResult blockRetry() {
        // Guarded because the empty list is immutable (see constructor); non-empty lists are private copies.
        if (!this.failures.isEmpty()) {
            this.failures.replaceAll(Failure::blockRetry);
        }
        return this;
    }

    /**
     * Gets the reprompt message.
     *
     * @return The first non-null reprompt message among the failures, or {@link Optional#empty()} if this is a
     *         success or no failure carries one.
     */
    public Optional<String> getReprompt() {
        return !isSuccess()
                ? this.failures.stream()
                        .map(Failure::reprompt)
                        .filter(Objects::nonNull)
                        .findFirst()
                : Optional.empty();
    }

    @Override
    public String toString() {
        return asString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        OutputGuardrailResult that = (OutputGuardrailResult) o;
        return result == that.result
                && Objects.equals(successfulText, that.successfulText)
                && Objects.equals(successfulResult, that.successfulResult)
                && Objects.equals(failures, that.failures);
    }

    @Override
    public int hashCode() {
        return Objects.hash(result, successfulText, successfulResult, failures);
    }

    /**
     * Gets the response computed from the combination of the original {@link ChatResponse} in the {@link OutputGuardrailRequest}
     * and this result.
     * <p>
     * If this result carries a rewritten result object, that object is returned as-is; otherwise a
     * {@link ChatResponse} is built from the request's LLM response.
     *
     * @param request The output guardrail request
     * @param <T> The type of response
     * @return A response computed from the combination of the original {@link ChatResponse} in the {@link OutputGuardrailRequest}
     * and this result
     */
    @SuppressWarnings("unchecked")
    public <T> T response(OutputGuardrailRequest request) {
        // T is chosen by the caller; the cast is unchecked and a mismatch surfaces at the call site.
        return (T) Optional.ofNullable(successfulResult).orElseGet(() -> createResponse(request));
    }

    /**
     * Builds a {@link ChatResponse} from the LLM response in {@code params}, substituting the AI message text
     * with {@link #successfulText()} when this result has a rewritten result. Tool execution requests of the
     * original message are preserved.
     */
    private ChatResponse createResponse(OutputGuardrailRequest params) {
        var response = params.responseFromLLM();
        var aiMessage = response.aiMessage();
        var newAiMessage = aiMessage;

        if (hasRewrittenResult()) {
            newAiMessage = aiMessage.hasToolExecutionRequests()
                    ? AiMessage.from(successfulText(), aiMessage.toolExecutionRequests())
                    : AiMessage.from(successfulText());
        }

        return response.toBuilder().aiMessage(newAiMessage).build();
    }

    @Override
    public Result result() {
        return result;
    }

    /**
     * Gets the failures of this result.
     * <p>
     * The returned list is the internal one: it is empty for successful results.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <F extends GuardrailResult.Failure> List<F> failures() {
        return (List<F>) failures;
    }

    @Override
    public String successfulText() {
        return successfulText;
    }

    /**
     * Gets the rewritten result object, or {@code null} if this result carries none.
     */
    public Object successfulResult() {
        return successfulResult;
    }

    /**
     * Represents an output guardrail failure
     */
    public static final class Failure implements GuardrailResult.Failure {
        private final String message;
        private final Throwable cause;
        private final Class<? extends Guardrail> guardrailClass;
        private final boolean retry;
        private final String reprompt;

        Failure(
                String message,
                Throwable cause,
                Class<? extends Guardrail> guardrailClass,
                boolean retry,
                String reprompt) {
            this.message = ensureNotNull(message, "message");
            this.cause = cause;
            this.guardrailClass = guardrailClass;
            this.retry = retry;
            this.reprompt = reprompt;
        }

        Failure(String message) {
            this(message, null);
        }

        Failure(String message, Throwable cause) {
            this(message, cause, false);
        }

        Failure(String message, Throwable cause, boolean retry) {
            this(message, cause, null, retry, null);
        }

        Failure(String message, Throwable cause, boolean retry, String reprompt) {
            this(message, cause, null, retry, reprompt);
        }

        /**
         * Creates a copy of this failure attributed to the given guardrail class.
         *
         * @param guardrailClass The guardrail class; must not be {@code null}.
         * @return A new failure with the same message, cause, retry flag and reprompt.
         */
        @Override
        public Failure withGuardrailClass(Class<? extends Guardrail> guardrailClass) {
            ensureNotNull(guardrailClass, "guardrailClass");
            return new Failure(message(), cause(), guardrailClass, this.retry, this.reprompt);
        }

        @Override
        public String message() {
            return message;
        }

        @Override
        public Throwable cause() {
            return cause;
        }

        @Override
        public Class<? extends Guardrail> guardrailClass() {
            return guardrailClass;
        }

        /**
         * Create a failure from this failure that blocks retries.
         * <p>
         * A retrying failure is replaced by a non-retrying one with a fixed message; the reprompt message is
         * kept as-is. A non-retrying failure is returned unchanged.
         */
        public Failure blockRetry() {
            return this.retry
                    ? new Failure(
                            "Retry or reprompt is not allowed after a rewritten output",
                            cause(),
                            this.guardrailClass,
                            false,
                            this.reprompt)
                    : this;
        }

        @Override
        public String toString() {
            return asString();
        }

        /**
         * Whether this failure requests a retry.
         */
        public boolean retry() {
            return retry;
        }

        /**
         * The reprompt message, or {@code null} if none was provided.
         */
        public String reprompt() {
            return reprompt;
        }
    }
}
